SDPA decode perf improvements for qwen-3.5-35B-A3B #18759

Merged
digantdesai merged 5 commits into main from digantdesai/sdpa-bench-and-perf-stats on Apr 15, 2026

Conversation

@digantdesai
Contributor

@digantdesai digantdesai commented Apr 8, 2026

Performance Improvements for SDPA

Improves SDPA performance for decode sequences where $L_q = 1$.

Benchmark: qwen3.5-35B-A3B

  • Config: Avg of 3 runs on A100.

Decode Performance (tok/s)

| Prompt | Decode Len | Baseline | Split-K | Speedup |
|---|---|---|---|---|
| P2 (1 tok) | 16 | 86.3 | 105.3 | +22% |
| P2 | 64 | 87.1 | 108.4 | +24% |
| P2 | 256 | 89.2 | 108.3 | +21% |
| P2 | 1024 | 89.4 | 108.1 | +21% |
| P15 (15 tok) | 16 | 85.9 | 102.8 | +20% |
| P15 | 64 | 85.1 | 104.7 | +23% |
| P15 | 256 | 88.4 | 108.8 | +23% |
| P15 | 1024 | 90.0 | 107.2 | +19% |
| P59 (59 tok) | 16 | 86.8 | 96.1 | +11% |
| P59 | 64 | 89.5 | 99.8 | +12% |
| P59 | 256 | 88.9 | 108.5 | +22% |
| P59 | 1024 | 90.0 | 108.6 | +21% |
| P120 (143 tok) | 16 | 87.5 | 105.0 | +20% |
| P120 | 64 | 88.8 | 107.7 | +21% |
| P120 | 256 | 90.3 | 107.6 | +19% |
| P120 | 1024 | 89.4 | 109.3 | +22% |
| P1000 (1694 tok) | 16 | 86.4 | 103.2 | +19% |
| P1000 | 64 | 89.2 | 106.7 | +20% |
| P1000 | 256 | 90.2 | 108.0 | +20% |
| P1000 | 1024 | 89.7 | 108.0 | +20% |

Prefill Performance (tok/s)

| Prompt | Baseline | Split-K | Delta |
|---|---|---|---|
| P2 (1 tok) | 19.4 | 19.3 | ~same |
| P15 (15 tok) | 192.8 | 191.6 | ~same |
| P59 (59 tok) | 390.2 | 368.1 | -6% |
| P120 (143 tok) | 512.4 | 481.9 | -6% |
| P1000 (1694 tok) | 585.6 | 591.4 | +1% |

Note: Prefill averaged across all 4 decode lengths per prompt since prefill is independent of decode length.


Summary

  • Decode: Split-K delivers +20% average (88.6 → 106.5 tok/s)
  • Prefill: similar between variants (both use tiled SDPA)
  • Quality: Verified identical at temperature=0

(At the SDPA op level this is a ~25x speedup: across ~10.2K calls (1024 tokens x 10 layers), total SDPA time dropped from 5.3 s to 209 ms.)

Implementation Details

  • Max Context Length: 4K
  • Kernel Constraints:
    • Baseline: Updates example input shapes to remove the 64-token cap.
    • Prefill: Baseline and Split-K should be equivalent for prefill (both use _sdpa_fwd_kernel_m64).

@pytorch-bot

pytorch-bot Bot commented Apr 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ No Failures, 114 Pending

As of commit d5209fc with merge base 841181e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 8, 2026
@github-actions

github-actions Bot commented Apr 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 2cb04c3 to febc419 on April 8, 2026 04:12
@digantdesai digantdesai changed the title from "[aoti-cuda] Add SDPA benchmarking script with qwen-3.5-35B-A3B shapes" to "SDPA decode perf improvements for qwen-3.5-35B-A3B" on Apr 9, 2026
@digantdesai digantdesai marked this pull request as ready for review April 9, 2026 17:44
@digantdesai digantdesai requested a review from lucylq as a code owner April 9, 2026 17:44
Copilot AI review requested due to automatic review settings April 9, 2026 17:44
Contributor

Copilot AI left a comment


Pull request overview

This PR improves ExecuTorch CUDA SDPA decode performance for the common decode case where Lq = 1 (e.g., Qwen3.5 MoE generation), by introducing a Split-K “flash-decoding” Triton path and dispatching to it at runtime.

Changes:

  • Add a Split-K decode SDPA Triton kernel (sdpa_decode_splitk) plus a reduction kernel to improve occupancy when L_q == 1.
  • Update the Qwen3.5 MoE attention path to dispatch between Split-K (decode) and tiled SDPA (prefill) via torch.cond.
  • Add correctness tests and a benchmark script for SDPA decode shapes; update export example shapes to avoid overly-small AOTI shape specialization.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
examples/models/qwen3_5_moe/model.py Switch attention to Triton SDPA and add decode-time Split-K dispatch via torch.cond.
examples/models/qwen3_5_moe/main.cpp Plumb a stats callback into generation and print throughput/timing breakdown.
examples/models/qwen3_5_moe/export.py Use a max-length example sequence to prevent AOTI from baking in too-small intermediate buffers.
backends/cuda/triton/kernels/sdpa.py Implement Split-K decode kernel + reduction and expose sdpa_decode_splitk.
backends/cuda/triton/kernels/init.py Export sdpa_decode_splitk from the kernels package.
backends/cuda/tests/test_triton_sdpa_splitk.py Add CUDA BF16 unit tests validating Split-K correctness vs PyTorch SDPA reference.
backends/cuda/benchmarks/benchmark_sdpa.py Add a benchmark script comparing Triton SDPA/Split-K vs PyTorch SDPA backends.


Comment on lines +1369 to +1390
```python
@triton_op("triton::sdpa_decode_splitk", mutates_args={})
def sdpa_decode_splitk(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
```


Copilot AI Apr 9, 2026


sdpa_decode_splitk() launches kernels that assume CUDA + bfloat16 inputs (and the reduce kernel stores bfloat16 unconditionally), but unlike sdpa() it never calls _validate_sdpa_inputs(). This can lead to hard-to-debug failures or incorrect behavior if the op is called with non-bfloat16 or non-CUDA tensors. Consider reusing _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) at the top of sdpa_decode_splitk (and keep the existing feature/shape checks).

Comment on lines +1339 to +1343
```python
    HAS_MASK=HAS_MASK,
    HEAD_DIM=D,
    NUM_GROUPS=num_groups,
    BLOCK_G=_next_power_of_2(num_groups),
)
```

Copilot AI Apr 9, 2026


BLOCK_G is derived via _next_power_of_2(num_groups), but _next_power_of_2() is documented as clamping to [16, 256] for head-dim tiling. Using it for num_groups forces a minimum of 16 (wasted work for MHA/low-group configs) and would be incorrect if num_groups > 256 (since BLOCK_G would clamp to 256 and not cover all groups). Consider a separate next_pow2 helper for group counts (no min clamp) and/or add a clear runtime check that num_groups <= 256 if that’s an intended constraint.

```python
device_name = torch.cuda.get_device_name()
print()
print("=" * 100)
print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
```
Copy link

Copilot AI Apr 9, 2026


The benchmark header prints "Qwen3.5-35B-A4B", but this benchmark (and the PR description) refers to the A3B variant. This looks like a typo and can confuse readers when comparing numbers; consider correcting the printed model name.

Suggested change:
```diff
-    print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
+    print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
```

@digantdesai
Contributor Author

@claude review this PR.

@claude

claude Bot commented Apr 9, 2026

Claude finished @digantdesai's task in 2m 21s


PR Review: SDPA decode perf improvements for qwen-3.5-35B-A3B

  • Read all changed files
  • Review Split-K decode kernel implementation (sdpa.py)
  • Review tests (test_triton_sdpa_splitk.py)
  • Review benchmark script (benchmark_sdpa.py)
  • Review model/export changes (model.py, export.py, main.cpp)
  • Post review feedback

Overall this is a well-structured PR. The Split-K flash-decoding approach is the right strategy for L_q=1 with low H_kv — the occupancy problem is real and the 1.45x end-to-end speedup is meaningful. The kernel implementation, torch.cond dispatch, and test coverage are solid. A few issues below, roughly ordered by severity.


Issues

1. Missing input validation in sdpa_decode_splitk (sdpa.py:1386-1430)

sdpa() calls _validate_sdpa_inputs() (checking CUDA device, bf16 dtype, 4D shape, dropout=0) before any kernel work. sdpa_decode_splitk() does its own partial validation (dropout, is_causal, L_q, H_q divisibility, pow2 D) but never checks that tensors are CUDA or bf16. Passing float32 or CPU tensors would reach the Triton kernel and produce silent wrong results or cryptic errors.

Call _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa=True) at the top, then keep the Split-K-specific checks (L_q==1, pow2 D) afterward. The existing dropout_p != 0 check would then become redundant and can be removed.


2. _next_power_of_2 used for BLOCK_G has the wrong domain (sdpa.py:1342)

_next_power_of_2() clamps to [16, 256], designed for HEAD_DIM tiling. When used for BLOCK_G = _next_power_of_2(num_groups), MHA configs (num_groups=1) get BLOCK_G=16, which means the split-K kernel allocates 16-wide vectors/accumulators but only uses 1 element, wasting 15/16 of register resources and compute. For the Qwen3.5 config (num_groups=8), BLOCK_G=16 — still 50% waste.

This doesn't cause incorrect results (the g_valid mask handles it), but it's a missed optimization opportunity. A separate helper without the min=16 clamp would give BLOCK_G=1 for MHA, BLOCK_G=8 for Qwen3.5, etc.


3. Reduce kernel output addressing assumes B=1 (sdpa.py:1259-1262)

The reduce kernel uses pid * stride_oh to compute the output pointer, where grid_reduce = (B * H_q,). For B>1, the output is [B, H_q, 1, D] and stride_oh = 1 * D (for contiguous layout). The correct offset for pid = b * H_q + h_q is b * stride_ob + h_q * stride_oh + 0 * stride_om. Since stride_ob = H_q * 1 * D = H_q * stride_oh, writing pid * stride_oh is equivalent to (b * H_q + h_q) * stride_oh = b * stride_ob + h_q * stride_oh, which is correct only when stride_ob == H_q * stride_oh (i.e., the tensor is contiguous in the B dimension). This holds for freshly-allocated contiguous tensors, but it's fragile and the comment at line 1259 should note this assumption more explicitly.

The same pattern applies to the partial buffers — pid * stride_op_h assumes contiguous layout of O_partial[splits, B, H_q, D]. Since these are allocated in _launch_decode_splitk and guaranteed contiguous, this is safe, but worth a comment.
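The stride equivalence described above is easy to check directly (a standalone sanity check, not the kernel code):

```python
import torch

# For a contiguous [B, H_q, 1, D] output, stride_ob == H_q * stride_oh,
# so pid * stride_oh with pid = b * H_q + h equals the full offset
# b * stride_ob + h * stride_oh.
B, H_q, D = 2, 16, 64
out = torch.empty(B, H_q, 1, D)
stride_ob, stride_oh = out.stride(0), out.stride(1)
for b in range(B):
    for h in range(H_q):
        pid = b * H_q + h
        assert pid * stride_oh == b * stride_ob + h * stride_oh

# The shortcut breaks for non-contiguous layouts, e.g. a transposed view:
t = torch.empty(H_q, B, 1, D).transpose(0, 1)  # shape [B, H_q, 1, D]
assert t.stride(0) != H_q * t.stride(1)
```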

4. Benchmark typo: "A4B" should be "A3B" (benchmark_sdpa.py:181)

print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")

Should be "A3B" to match the model variant. Copilot flagged this too.



Suggestions (non-blocking)

5. Partial accumulator stores unscaled acc rather than acc / l_i (sdpa.py:1186-1204)

The split-K kernel stores the raw unnormalized accumulator acc and separate m_i / l_i. The reduce kernel then rescales and normalizes. This is the standard flash-decoding approach and is correct. However, storing acc without the l_i normalization means partial outputs can have very large magnitudes (especially with long chunks), which might lose precision in float32 for extreme cases. Not a problem for the 4K max context here, but worth keeping in mind if this is extended to longer contexts.

6. num_splits heuristicsdpa.py:1283

num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)

The choice of 256 tokens per split and max 128 splits is reasonable for A100 with ~108 SMs, but this could benefit from being SM-count-aware (e.g., torch.cuda.get_device_properties().multi_processor_count). For smaller GPUs (e.g., A10 with 72 SMs), 128 splits may over-partition. Non-blocking since the target is A100.
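A sketch of the SM-count-aware variant suggested above (`pick_num_splits` and its parameters are illustrative names; the merged kernel keeps the fixed heuristic):

```python
import torch

def pick_num_splits(L_kv: int, tokens_per_split: int = 256,
                    max_splits: int = 128) -> int:
    """Cap the split count at the device's SM count so smaller GPUs
    (e.g. a 72-SM A10) don't over-partition the KV sequence."""
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        max_splits = min(max_splits, props.multi_processor_count)
    # Ceiling division: at least one split even for very short KV.
    num_splits = max((L_kv + tokens_per_split - 1) // tokens_per_split, 1)
    return min(num_splits, max_splits)

assert pick_num_splits(64) == 1   # short KV: a single split suffices
assert pick_num_splits(4096) <= 128
```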

7. Test tolerance (test_triton_sdpa_splitk.py)

All tests use max_abs_error < 0.05, which is reasonable for bf16 SDPA (expected ~1e-2 to 3e-2 range). The cross-validation test against the tiled SDPA kernel (test_cross_validate_with_sdpa) also uses 0.05 — this is good since both paths accumulate differently.
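For context on the 0.05 bound, here is a quick illustration of typical bf16 SDPA error against a float32 reference, using the CPU math backend (`max_abs_error` mirrors the test helper's role but is written here from scratch):

```python
import torch
import torch.nn.functional as F

def max_abs_error(a: torch.Tensor, b: torch.Tensor) -> float:
    return (a.float() - b.float()).abs().max().item()

torch.manual_seed(0)
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

ref = F.scaled_dot_product_attention(q, k, v)           # float32 reference
approx = F.scaled_dot_product_attention(
    q.bfloat16(), k.bfloat16(), v.bfloat16()
).float()                                               # bf16 end to end

# bf16 rounding lands well inside the 0.05 tolerance used by the tests.
assert max_abs_error(approx, ref) < 0.05
```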

8. torch.cond dispatch in model (model.py:276-281)

The torch.cond dispatch is clean:

```python
y = torch.cond(
    q.shape[2] == 1,
    lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
    lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
    [q, k, v, attn_mask],
)
```

Note that during AOTI tracing, both branches are traced with symbolic shapes, so sdpa_decode_splitk correctly guards its L_q==1 assertion with isinstance(L_q, int) (line 1404). This is a good pattern.

9. Export shape change (export.py:381-383)

Changing example_seq_len to config.max_seq_len - 1 is the right fix. The previous smaller example caused AOTI to specialize intermediate buffers for small sizes, capping effective prefill at 64 tokens. The comment explaining this is helpful.


Summary

The kernel implementation is correct and well-tested. The main actionable item is issue #1 (missing CUDA/bf16 validation in sdpa_decode_splitk), which is a real gap that could cause confusing failures. Issue #2 (BLOCK_G over-allocation) is a perf optimization opportunity. Issue #4 is a trivial typo fix.


@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from ebe61e8 to 5d3b620 on April 10, 2026 04:37
Copilot AI review requested due to automatic review settings April 10, 2026 04:59
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.



Comment thread examples/models/qwen3_5_moe/main.cpp Outdated
Comment on lines +90 to +93
```cpp
printf(
    "\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
    stats.num_prompt_tokens,
    stats.num_generated_tokens);
```
Comment on lines +240 to +248
```python
for name, label, _ in backends:
    if name == ref_name or outputs[name] is None:
        continue
    err = _max_abs_error(outputs[name], ref_out)
    assert err < 1e-2, (
        f"Output mismatch for {_shape_label(shape)}: "
        f"{label} vs {BACKENDS[ref_name][0]}, "
        f"max abs error {err:.3e} >= 1e-2"
    )
```
```python
out = self.splitk(q, k, v, attn_mask=mask)

self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN")
self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf")
```
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 62428be to f011e54 on April 10, 2026 19:38
@Gasoonjia
Contributor

Can you also list the prefill performance in the benchmark result?

Copilot AI review requested due to automatic review settings April 10, 2026 21:02
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from f011e54 to 3836bea on April 10, 2026 21:02
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.



```python
self.assertFalse(torch.isnan(out).any(), "NaN in output")
self.assertLess(
    _max_abs_error(out, ref),
    0.05,
```
Contributor

@Gasoonjia Gasoonjia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!
Please add a test for the split SDPA, and it would be great if we can confirm that prefill performance is not impacted.

Also, what's the difference between original and baseline in the perf comparison table?

```python
    key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK"],
)
@triton.jit
def _sdpa_decode_splitk_kernel(
```
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decode is not related to the sdpa kernel itself. Consider renaming it to _sdpa_splitk_kernel. Same for the others.

Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kernel only works with Lq == 1, hence the decode in the name and the assert in check_args.

Comment thread examples/models/qwen3_5_moe/model.py Outdated
```python
# The export produces two methods — decode (T=1, static) and
# prefill (T>=2, dynamic). Each traces only one branch, so no
# torch.cond is needed and we avoid GPU→CPU sync overhead.
if T == 1:
```
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in another PR, but I'm wondering whether we should apply the new split-K SDPA only when T == 1.

IIUC, the core idea of the split-K algorithm is to fully leverage the GPU's compute units. Given that

  1. batch size will always be 1 in our usage
  2. the A100 has 108 SMs

maybe we can apply the split-K SDPA even in the prefill case.

What if we use torch.cond here and let the runtime dynamically choose the right kernel?

Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow. The split-K kernel is for decode, where the existing kernel doesn't work well. The existing kernel works well for the prefill case with T > 1.

```python
@classmethod
def setUpClass(cls):
    _skip_if_no_cuda()
    cls.sdpa = _import_sdpa()
```
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now we only test the regular sdpa but not the split version. Please add a test for it.

```python
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton SDPA kernel against PyTorch SDPA backends.
```
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in another PR, but we could make this a perf CI job to guard the perf.

Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, but I don't trust CI perf numbers to be stable.

Contributor

I think some fluctuation should be fine, as long as the perf change is not too large.

@digantdesai
Contributor Author

what's the difference between original and baseline, in the perf comparison table?

In the PR summary.

@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 0a46be3 to 0c0f132 on April 13, 2026 16:25
Copilot AI review requested due to automatic review settings April 13, 2026 16:25
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.



```python
# prefill (T>=2, dynamic). Each traces only one branch, so no
# torch.cond is needed and we avoid GPU→CPU sync overhead.
if T == 1:
    y = sdpa_decode_splitk(q, k, v, attn_mask=attn_mask)
```
Comment on lines +1411 to +1414
```python
# is_causal is a no-op at L_q=1 (single query can't attend to future
# positions), so we accept it silently for API compatibility with callers
# that always pass is_causal=True for decode.
```

Comment on lines +1398 to +1403
```python
"""Split-K flash-decoding SDPA for L_q=1 (decode step).

Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
enable_gqa is accepted but ignored — GQA is handled natively via
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
"""
```
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 0c0f132 to 0609ae2 on April 13, 2026 16:35
Copilot AI review requested due to automatic review settings April 14, 2026 02:05
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 0609ae2 to 4de4538 on April 14, 2026 02:05
Copy link
Copy Markdown
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.



Comment on lines 53 to 55
```python
max_seq_len: int = 4096
use_splitk_decode: bool = True
layer_types: list = field(default_factory=list)
```
Comment on lines +1411 to +1422
```python
# is_causal is a no-op at L_q=1 (single query can't attend to future
# positions), so we accept it silently for API compatibility with callers
# that always pass is_causal=True for decode.

# Validation — only check at runtime (concrete shapes), not during AOTI
# tracing where shapes are symbolic. torch.cond traces both branches with
# the same symbolic L_q, so L_q is not necessarily 1 during tracing.
if isinstance(L_q, int):
    if L_q != 1:
        raise RuntimeError(
            f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
        )
```
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 4de4538 to 14bd4cb on April 14, 2026 02:56
Copilot AI review requested due to automatic review settings April 14, 2026 14:50
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 14bd4cb to 3069c79 on April 14, 2026 14:50
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.



Comment thread examples/models/qwen3_5_moe/model.py Outdated
```python
import torch
import torch.nn as nn

from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
```
Comment on lines +1391 to +1414
```python
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)

    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)

    # is_causal is a no-op at L_q=1 (single query can't attend to future
    # positions), so we accept it silently for API compatibility with callers
    # that always pass is_causal=True for decode.
```

Comment on lines +281 to +288
```python
def test_is_causal_rejected(self):
    """is_causal=True should raise RuntimeError."""
    B, H_q, H_kv, D = 1, 8, 2, 64
    q = torch.randn(B, H_q, 1, D, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda")
    with self.assertRaises(RuntimeError):
        self.splitk(q, k, v, is_causal=True)
```
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 3069c79 to 4d08bf0 on April 14, 2026 18:10
Copilot AI review requested due to automatic review settings April 15, 2026 01:33
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 4d08bf0 to f02b19a on April 15, 2026 01:33
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math
backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench
for timing. Standalone script, no changes to the kernel.

This PR was authored with the assistance of Claude
Register `triton::sdpa_decode_splitk` as an independent op so AOTI
can trace and compile it without the runtime L_kv conditional that
prevents the split-K path from appearing in the standard `sdpa` op.

The split-K (flash-decoding) approach partitions the KV sequence
across CTAs and reduces partial softmax results in a second kernel.
The benchmark script now includes the split-K column for comparison.

BLOCK_G (the GQA group tile) uses _next_power_of_2_unclamped() to
avoid inflating small group counts to 16. Phantom rows from
over-sized tiles change register pressure and instruction scheduling,
altering fp32 accumulation order enough to degrade output quality
over long autoregressive sequences.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16,
H_kv=2, D=256, bf16):

  Lk       ET Tiled (us)  ET Split-K (us)  Speedup
  64            131.8          259.5         0.5x
  512            98.9          221.5         0.4x
  4096          199.9          214.4         0.9x
  8192          392.2          211.3         1.9x
  16384         775.3          211.8         3.7x

Split-K breaks even around Lk=4096 and dominates at longer sequences
where the tiled kernel's single-CTA-per-head bottleneck becomes severe.

This PR was authored with the assistance of Claude
The previous example used T=2, which caused AOTI to compile the
chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime,
prompts longer than 64 tokens (requiring NT>1 chunks) failed with
"Error resizing tensor at input 0". Using max_seq_len-1 as the
example ensures AOTI generalizes intermediate buffer sizes for the
full sequence length range.

Comparison against original export (tq4_sdpa fused kernel)
on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median):

                Original (tq4_sdpa)  Baseline (Triton SDPA)
  Decode tok/s       68.4               61.7
  Prefill tok/s     275.7              378.2

Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused
decode kernel is faster than the tiled Triton SDPA at L_q=1). The
split-K commit addresses the decode gap.

This PR was authored with the assistance of Claude
Dual-method export (decode T=1, prefill T>=2) lets the model use a
simple if/else on T instead of torch.cond, eliminating the GPU-to-CPU
sync overhead that torch.cond's predicate evaluation requires.

Decode calls sdpa_decode_splitk (split-K flash-decoding for high KV
occupancy), prefill calls tiled sdpa. Guard sdpa_decode_splitk
validation behind isinstance(L_q, int) so AOTI tracing with symbolic
shapes doesn't trip the L_q==1 check.

Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal,
enable_gqa) for consistent API; unsupported args fail with clear
messages.

This PR was authored with the assistance of Claude
Add `use_splitk_decode` config flag to control whether FullAttention
uses the split-K (flash-decoding) SDPA kernel or the tiled SDPA for
decode (T=1). The split-K kernel partitions the KV sequence across
CTAs, yielding ~20% higher decode throughput on H100:

  Variant          Decode tok/s (avg across prompts)
  Tiled SDPA       88.5
  Split-K SDPA     107.5   (+21%)

The flag defaults to True (split-K on). Pass `--no-splitk` at export
time to disable. Quality is verified identical at temperature=0.

This PR was authored with the assistance of Claude
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 1af2029 to d5209fc on April 15, 2026 01:38
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.



@digantdesai digantdesai merged commit 87e65ac into main Apr 15, 2026
219 of 222 checks passed
@digantdesai digantdesai deleted the digantdesai/sdpa-bench-and-perf-stats branch April 15, 2026 02:16

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
